{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BNN on Pynq\n",
    "\n",
    "This notebook covers how to use Binary Neural Networks on Pynq. \n",
    "It shows an example of handwritten digit recognition using a binarized neural network composed of 4 fully connected layers with 1024 neurons each, trained on the MNIST dataset of handwritten digits. \n",
    "In order to reproduce this notebook, you will need an external USB Camera connected to the PYNQ Board.\n",
    "\n",
    "## 1. Import the package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "\n",
       "require(['notebook/js/codecell'], function(codecell) {\n",
       "  codecell.CodeCell.options_default.highlight_modes[\n",
       "      'magic_text/x-csrc'] = {'reg':[/^%%microblaze/]};\n",
       "  Jupyter.notebook.events.one('kernel_ready.Kernel', function(){\n",
       "      Jupyter.notebook.get_cells().map(function(cell){\n",
       "          if (cell.cell_type == 'code'){ cell.auto_highlight(); } }) ;\n",
       "  });\n",
       "});\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import bnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Checking available parameters\n",
    "\n",
    "By default the following trained parameters are available for LFC network using 1 bit for weights and 1 threshold for activation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['chars_merged', 'mnist']\n"
     ]
    }
   ],
   "source": [
    "print(bnn.available_params(bnn.NETWORK_LFCW1A1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two sets of weights are available for the LFCW1A1 network, the MNIST and one for character recognition (NIST)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Instantiate the classifier\n",
    "\n",
    "Creating a classifier will automatically download the correct bitstream onto the device and load the weights trained on the specified dataset. This example works with the LFCW1A1 for inferring MNIST handwritten digits.\n",
    "Passing a runtime attribute will allow to choose between hardware accelerated or pure software inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "hw_classifier = bnn.LfcClassifier(bnn.NETWORK_LFCW1A1,\"mnist\",bnn.RUNTIME_HW)\n",
    "sw_classifier = bnn.LfcClassifier(bnn.NETWORK_LFCW1A1,\"mnist\",bnn.RUNTIME_SW)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n"
     ]
    }
   ],
   "source": [
    "print(hw_classifier.classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Load the image from the camera\n",
    "The image is captured from the external USB camera and stored locally. The image is then enhanced in contract and brightness to remove background noise. \n",
    "The resulting image should show the digit on a white background:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAHACAAAAADxE48PAAAOVUlEQVR4nO2dPWgcSRqGq48FyaDMNijzCJStBQ4WbMNmtmCDA0twgWHleCXYYMFWdnAybGA4KTNYgg0OLMEGCytnhpWyhZMyw1qZYUeZwRJcYFhN1Bdo+m+6+qe6q1Vd3/s+yVT39AzfVD3zVfVfdRAqgszfXAdA3EIBwKEA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDAcChAOBQAHAoADgUABwKAA4FAIcCgEMBwKEA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDAcChAOBQAHAoADgUABwKAA4FAIcCgEMBwKEA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDAcChAOBQAHAoADgUABwKAA4FAIcCgEMBwKEA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBcgxch3AlfKF6wB6xeGH2ZcHSi3dmZ2ZnZmdUer6aEqNppQ6vz6ach1cNwSh6wh6wWhq/T9nVRttzj+6iliuFgqgDj+umH1g83slJx1QgJ1V448Mvn7dQSBuABdgbbvxR4VUHPBewOnzIGje/ioIzu3F4g7cDHDy+H3Lbxj8fNdKJE4BFeD0h307X+R99SEKMPr9ob0v873+EAVYPLD5bZ5XIN4g8DAobf97B6GGi+JDBZ4PBsEywPG9oneGt+p+Rz6B7H7bNB73gAkwNyx4Y9/gKO/Ji92JNR5XIpQATyYbTiml1OCnBw2+K8gs+VuLSAIE+VWrP15v+m0nt9NL3lYj0CDweX7V7qvG7a++zKSTxcbf4xicDJD7/581b3zdV3pakSgZ4Plk+98LW7e/CsPpZEHTwfgASgaYbJ6lX+1/r5dVCZIB3mQXV0JL7Z9pdC9zAIgAS5ml9xav5whTh5Y0w8zeA9IFdDlcW96Pi6uv7H71FQAogPVfnPpy/2oTpAsYxqUV+20U3oiLgXc3FYAIcCvaX9vs4nLOT0nxYwdf3ykgXUDHPN+Ii77VJwWwwl5yvYBnFUoBrHCeDAM8q1AKYIdRclDYrxoFGQR2zlTS7DsOwzCHGcAW8dGA6b9chmEKM4At4n/ShVfXBjADWCM5IOhTnTIDWCNp9mWHUZhCAeyxERX23cVgDLsAi8SdgEeVygzQBR71AcwANvEwBTAD2KRynqn+QQFsEp8MXnMZhRHsAqziXx/ADGCVoesAjKEAVonvMT91GYUJFKAb/u46gLpwDGCXeAICX+qVGcAuP0aFQ5dRGMAMYBnf9gOYAcChAOBQAHAoADgUABwKAA4FsEzhVKQ9hQJYZj4qeHIkiAJY5k5UeOwwCAN4JNAyp4Oo5EfNUgDbeHYsmF1AV0xXb9IHKEBXzFdv0gcoQFfMug6gHhSgK+64DqAeFKArPOkCuBdgm2gvwJOKZQYAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDASzjzcQAYyiAZX4Zv94o3ao/UADLvBy//uQ0ivrwbKBlPDsZyAxgG94YAs7R5cuS0yAMoADd8A/XAdSFYwDLjMcA3lQrMwA4FAAcCgAOBbDLuusATKEAdnnnOgBTuBdgF98OBDIDdMOm6wBqQwGsEg0BnjqNwgR2AVa5dnH56k+tUgCreDcEYBeADgXogmeuA6gPuwCb7K1cvnpUqcwANnlZvUnfoAA2+ew6AHMogEUW3ruOwBwKYJGo/R86jcIMDgLtcR7dC+BTnTID2OOfrgNoAjOAPTybJfgSZgD7DF0HYAIFsMba+HXpVulmPYMCWGNm/Pqz0yhMoQDWeDt+nXIahSkcBFrDv1PBSjED2GNr/OrRmUClmAHsESWAgwdOwzCFAlji5tm44FmFsguww1bU/ttOwzCHGcAOXh4F
VIoZwBInrgNoDDOAFeIEwAwAzoHrAEyhADZIHhTt1z6gYhdgB397AGYAu/h3USAFsMBWXPrSYRTNoAAWiG8HWHEZRTM4BrCAt0eBFDOADRajwr7DIJrCDNCaw/g2AB/rkhmgNT9EhQ13MTSHGaAt94+ikpdVyQzQlrj9fbohLIECtCQ5BvCbwyiaQwFaMlO9Sa/hGKAlPh8DUIoZoC3HrgNoCwVox4uosO8wiDawC2hH1AOsvnIaRnMoQCtObo8L3lYjBWiF70NAjgHa8SYq7DsMoh0UoAWn76LSI4dRtINdQAs8vhQwhhmgOedxyd/2pwAt+Mp1ADagAI3ZG0Ylz6YEyMAxQGMkjACYAWzg54UAY5gBmuL5lUARX7gOwFfiGUH8ux80AzNAQ+IRwPRfLsNoDccAzUieEex3+7MLaMbcMCr5nkHZBTQhfjKA/wKwC2hC0v7+3Q8+ATNAA5IOwPsEwAzQgOVhXDwq3soTmAHMSY4B+58AmAGMOV2Oiw/9b39mAGNS///3/s0Ik4MCGJIaAAroACiAKaPppCyi6jgGMCPV/hvOgrAJM4AR91P7fTJqjhnAhGVx7c8MYISwPQClKIAJp4PUgpR6YxdQn0FS3JXS/swA9UnlfzH/f2aA+sylynLanwLU5XiYlP2+DDQLu4CaXLtIypLqjBmgHutC258C1GN9My7eENX+7ALqIeoakAzMAHVItf+quyg6gRmgBqn2v/dfd2F0AgWoQUqAs+vuwugEdgHVpNp/W1r7MwNUI/MQcAQzQBVCDwFHUIBy1oJhsuDhYwErYRdQSvoaUDW85SyO7mAGKOVxekFi+3N+gFLSZ4DUReFmPsMMUMx5kGn/KWeBdAkFKOZlZklm+7MLKGZ5P70kdbDMDFDEzn5q4UBq+3M3sIjMNeA3PrkKo3OwBNh5Oz//4cPM7OeZ79+VP+PhMDP/q+BKAhFgpL6aUbPzn7dT655986Bo82zzS25/EAFOH8//Psyv3v22YPNBZlF0FSEIsPeiaDI3/eUda9uZRaEHAMbIF+DJL2WH8DQKLGYv+xdeQeIFyJzO0ZGrgKD8bWFIPw6wVtX+aiG7OIfV/sIzwMRoroB0FRzfy7wl7hLAHLIFCKo3UdlRXub8n7xrgPNI7gIW67V/euKnzPk/9V5++0s+GbRX+ybeOA1mjJGf/pWS3AUsmMzkflkLYOM/pZTgLmDdaCb/a0qpJ5n2lzAPcB2kZoAt06d5rrzOtv9vFoPpM1IF0I//Np+q9APfShBaLXmEdgFvNOt2w/CpUkp9CsOwwoEBTPtLzQC5BFBxwDfLymurwfQamQJMtq7u4Y4lBoiskwJEdgE72cWzUPdwx8KZHoTNAVOBxAxQd3c+/eyHBKT8rwAEKPt9ul4A4PB/BoFdwGKqvFrqt+5NsPaXmAEMZvTSpICDwitFZSIwAyQ0mdHroe4QgmDkCbAXlzZfVWy6plu5ZDEWD5DXBSRpveqnLRacL958ai+a3iMvA8RUXA6wHhRtYHoeyWvEXRByGhV2S0dzqcl/8wRAxwLFdQFxXi/7Yec3qr4G5nSwuC7g6xrbLFS2v6hnQpQiToA7lVvcD+pcLBTUvKLUd8R1AfGdQPofdvNz/bmepFWNFnEZoPxOzuUzTfuH4VF+pVLq0EY8fUdcBoiPA2h+WHbWnzGXwz19whdXOXnkCjBxEUhRl669JSD3rlzEdQExF1vppVHBPL9hrgCGPAHi43jPkkP95zend7XbpppdexvxnG6lKOR1Aelkvvvim48zb4dFW+5nJ4rS9QKC5we7RLYA5Uz+9ImpofQbSQNVgOn/6fYXNR8d/NkynJ4DKUDxxP+6z8qroTTyBoFqo2qDsHjif92Z4OctYuk/AjNARQqo+MFolwkKzABqWPJe5azPYX4k+FByDpCYASZnekxR564PTQ6QWEljRApQ0Akc3W3+aZHVpJTMLkDfXBthzfbXflrs5QEyM4DmX2w04y/QPWMyM0D+
XxwazfgcbuTXHcnMAVIzgMr8jRtc6a9tboGVJTUDqLi1NsPx1DBmLOlWXmsRTk8RnAGUOnnb4h4f7akheTlAtAAt0fYCRU8Z8RXBXUBrtP+NlR3dWn+hACXodgbUqqy9AXYB5YjvBpgBytHeQ7qypVvrJ8wAFTzRXk0qZy4xClBJQZ8vpOLYBVRS0NJCxoIUoJoiA7RTDPkGu4A6FPzbJVwrRgFqUXVnob+wC6hFGCqlbufXB97vEDID1CVQSk2r/PQCnlcgM0BdQqXUhWZ6iWAhv84jmAHqU/wcao8rkRmgPlMFU8koFQR7Be/0Hgpgwt0NpZRSg/w7K8HJ5KpRx8HYgQIY8a9wVRXcenQ7SCaVOgyCIJgOgrnlK4qrORwDmGN0ELjv9csMYE44qL9t4bixLzADNOJ0UHfLvtcvM0AjboVhzUnlj7sNpDUUoCn/rjef9Ltuo2gNBWjMg7DOtMOzncfRDgrQgqkwLHvuhFJKqUdVGziGg8CWXKtIA32vXwrQmtLDAr2vXnYBrQnDwlbe7n37UwArhNp7iA7C7646EHPYBdhkIXkYjS8TijAD2OSPy+nJzsIw9KT9KYBdvlNKqY3rrsMwgF0AOMwA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDAcChAOBQAHAoADgUABwKAA4FAIcCgEMBwKEA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDAcChAOBQAHAoADgUABwKAA4FAIcCgEMBwKEA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDAcChAOBQAHAoADgUABwKAA4FAIcCgEMBwKEA4FAAcCgAOBQAHAoADgUAhwKAQwHAoQDgUABwKAA4FAAcCgAOBQCHAoBDAcChAOBQAHD+DzWsiFK66S4XAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=512x448 at 0x7F4D6BCE80>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import cv2\n",
    "from PIL import Image as PIL_Image\n",
    "from PIL import ImageEnhance\n",
    "from PIL import ImageOps\n",
    "\n",
    "# says we capture an image from a webcam\n",
    "cap = cv2.VideoCapture(0) \n",
    "_ , cv2_im = cap.read()\n",
    "cv2_im = cv2.cvtColor(cv2_im,cv2.COLOR_BGR2RGB)\n",
    "img = PIL_Image.fromarray(cv2_im).convert(\"L\") \n",
    "\n",
    "#original captured image\n",
    "#orig_img_path = '/home/xilinx/jupyter_notebooks/bnn/pictures/webcam_image_mnist.jpg'\n",
    "#img = PIL_Image.open(orig_img_path).convert(\"L\")     \n",
    "                   \n",
    "#Image enhancement                \n",
    "contr = ImageEnhance.Contrast(img)\n",
    "img = contr.enhance(3)                                                    # The enhancement values (contrast and brightness) \n",
    "bright = ImageEnhance.Brightness(img)                                     # depends on backgroud, external lights etc\n",
    "img = bright.enhance(4.0)          \n",
    "\n",
    "#img = img.rotate(180)                                                     # Rotate the image (depending on camera orientation)\n",
    "#Adding a border for future cropping\n",
    "img = ImageOps.expand(img,border=80,fill='white') \n",
    "img"
   ]
  },
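  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of the two enhancement steps above can be approximated on plain pixel values. The following is a minimal sketch, assuming that the contrast enhancer pushes each pixel away from the mean gray level and that the brightness enhancer scales each pixel, both clipped to the 0..255 range. `enhance_contrast` and `enhance_brightness` are hypothetical helpers written for this illustration, not Pillow functions, and the sample row is made up.\n",
    "\n",
    "```python\n",
    "def enhance_contrast(pixels, factor):\n",
    "    # Push each pixel away from the mean gray level, then clip to 0..255\n",
    "    mean = int(sum(pixels) / len(pixels) + 0.5)\n",
    "    return [max(0, min(255, int(mean + factor * (p - mean)))) for p in pixels]\n",
    "\n",
    "def enhance_brightness(pixels, factor):\n",
    "    # Scale each pixel toward white, then clip to 0..255\n",
    "    return [max(0, min(255, int(factor * p))) for p in pixels]\n",
    "\n",
    "row = [30, 40, 120, 130, 125]   # two dark ink pixels followed by gray background\n",
    "row = enhance_contrast(row, 3)\n",
    "row = enhance_brightness(row, 4.0)\n",
    "print(row)\n",
    "```\n",
    "\n",
    "With these sample values the gray background saturates to white while the ink stays black, which is the result the enhancement is meant to produce. The factors 3 and 4.0 are the ones used in the cell above and may need tuning for other lighting conditions."
   ]
  },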
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Crop and scale the image\n",
    "The center of mass of the image is evaluated to properly crop the image and extract the written digit only. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image as PIL_Image\n",
    "import numpy as np\n",
    "import math\n",
    "from scipy import misc\n",
    "\n",
    "#Find bounding box  \n",
    "inverted = ImageOps.invert(img)  \n",
    "box = inverted.getbbox()  \n",
    "img_new = img.crop(box)  \n",
    "width, height = img_new.size  \n",
    "ratio = min((28./height), (28./width))  \n",
    "background = PIL_Image.new('RGB', (28,28), (255,255,255))  \n",
    "if(height == width):  \n",
    "    img_new = img_new.resize((28,28))  \n",
    "elif(height>width):  \n",
    "    img_new = img_new.resize((int(width*ratio),28))  \n",
    "    background.paste(img_new, (int((28-img_new.size[0])/2),int((28-img_new.size[1])/2)))  \n",
    "else:  \n",
    "    img_new = img_new.resize((28, int(height*ratio)))  \n",
    "    background.paste(img_new, (int((28-img_new.size[0])/2),int((28-img_new.size[1])/2)))  \n",
    "  \n",
    "background  \n",
    "img_data=np.asarray(background)  \n",
    "img_data = img_data[:,:,0]  \n",
    "misc.imsave('/home/xilinx/img_webcam_mnist.png', img_data) "
   ]
  },
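  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaling and centering arithmetic used when fitting the crop into the 28x28 canvas can be checked without any image library. This is a minimal sketch: `fit_and_center` is a hypothetical helper introduced only for illustration (it is not part of `bnn` or Pillow), and it rounds the scaled size, which is one reasonable choice rather than the only one.\n",
    "\n",
    "```python\n",
    "def fit_and_center(width, height, target=28):\n",
    "    # Scale factor that makes the longer side equal to the target size\n",
    "    ratio = min(target / width, target / height)\n",
    "    new_w = max(1, min(target, round(width * ratio)))\n",
    "    new_h = max(1, min(target, round(height * ratio)))\n",
    "    # Offsets that center the scaled crop on the target canvas\n",
    "    return new_w, new_h, (target - new_w) // 2, (target - new_h) // 2\n",
    "\n",
    "print(fit_and_center(100, 200))  # tall crop: full height, centered horizontally\n",
    "print(fit_and_center(300, 100))  # wide crop: full width, centered vertically\n",
    "print(fit_and_center(56, 56))    # square crop: fills the whole canvas\n",
    "```\n",
    "\n",
    "The first two values are the size passed to `resize`, and the last two are the top-left corner passed to `paste`."
   ]
  },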
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Convert to BNN input format\n",
    "The image is resized to comply with the MNIST standard. The image is resized at 28x28 pixels and the colors inverted. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAAcElEQVR4nNVS2w2AMAgE4wYO4NgO4ACO5AAOcX60goSz2k/vp+ReIaQiDgDAKgxAfZlGJtpxDUPL9S+R4GiJfqHO2ttpU3L6GOzTZkrvgP+GCo11KgRYJOXMCi3ZmBy9Vh9XLfyWmqPpRYXVk91h7Ak1jCnedgrNYwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28 at 0x7F4C403860>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from array import *\n",
    "from PIL import Image as PIL_Image\n",
    "from PIL import ImageOps\n",
    "img_load = PIL_Image.open('/home/xilinx/img_webcam_mnist.png').convert(\"L\")  \n",
    "# Convert to BNN input format  \n",
    "# The image is resized to comply with the MNIST standard. The image is resized at 28x28 pixels and the colors inverted.   \n",
    "  \n",
    "#Resize the image and invert it (white on black)  \n",
    "smallimg = ImageOps.invert(img_load)  \n",
    "smallimg = smallimg.rotate(0)  \n",
    "  \n",
    "data_image = array('B')  \n",
    "  \n",
    "pixel = smallimg.load()  \n",
    "for x in range(0,28):  \n",
    "    for y in range(0,28):  \n",
    "        if(pixel[y,x] == 255):  \n",
    "            data_image.append(255)  \n",
    "        else:  \n",
    "            data_image.append(1)  \n",
    "          \n",
    "# Setting up the header of the MNIST format file - Required as the hardware is designed for MNIST dataset         \n",
    "hexval = \"{0:#0{1}x}\".format(1,6)  \n",
    "header = array('B')  \n",
    "header.extend([0,0,8,1,0,0])  \n",
    "header.append(int('0x'+hexval[2:][:2],16))  \n",
    "header.append(int('0x'+hexval[2:][2:],16))  \n",
    "header.extend([0,0,0,28,0,0,0,28])  \n",
    "header[3] = 3 # Changing MSB for image data (0x00000803)  \n",
    "data_image = header + data_image  \n",
    "output_file = open('/home/xilinx/img_webcam_mnist_processed', 'wb')  \n",
    "data_image.tofile(output_file)  \n",
    "output_file.close()   \n",
    "smallimg"
   ]
  },
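  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 16-byte header assembled above consists of four big-endian 32-bit integers: the magic number `0x00000803`, the number of images, the number of rows and the number of columns. As an alternative sketch, the same bytes can be produced with the standard `struct` module; `idx_image_header` is a hypothetical helper name used only for this illustration.\n",
    "\n",
    "```python\n",
    "import struct\n",
    "\n",
    "def idx_image_header(count, rows, cols):\n",
    "    # Magic number 0x00000803: 0x08 = unsigned byte data, 0x03 = three dimensions\n",
    "    return struct.pack('>IIII', 0x00000803, count, rows, cols)\n",
    "\n",
    "header_bytes = idx_image_header(1, 28, 28)\n",
    "print(list(header_bytes))\n",
    "```\n",
    "\n",
    "For a single 28x28 image this yields the same byte values as the `header` array built in the cell above, followed in the file by the 784 pixel bytes in row-major order."
   ]
  },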
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Launching BNN in hardware\n",
    "\n",
    "The image is passed in the PL and the inference is performed. Use `classify_mnist` to classify a single mnist formatted picture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inference took 8.00 microseconds\n",
      "Classification rate: 125000.00 images per second\n",
      "Class number: 2\n",
      "Class name: 2\n"
     ]
    }
   ],
   "source": [
    "class_out = hw_classifier.classify_mnist(\"/home/xilinx/img_webcam_mnist_processed\")\n",
    "print(\"Class number: {0}\".format(class_out))\n",
    "print(\"Class name: {0}\".format(hw_classifier.class_name(class_out)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Launching BNN in software\n",
    "The inference on the same image is performed in sofware on the ARM core"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inference took 18588.00 microseconds\n",
      "Classification rate: 53.80 images per second\n",
      "Class number: 2\n",
      "Class name: 2\n"
     ]
    }
   ],
   "source": [
    "class_out=sw_classifier.classify_mnist(\"/home/xilinx/img_webcam_mnist_processed\")\n",
    "print(\"Class number: {0}\".format(class_out))\n",
    "print(\"Class name: {0}\".format(hw_classifier.class_name(class_out)))"
   ]
  },
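  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two runs above report both an inference time and a classification rate, and the rate is simply the reciprocal of the time. The short sketch below reproduces that arithmetic and derives the hardware speedup. The two timings are copied from the outputs recorded in this notebook and will differ from run to run; `images_per_second` is a helper defined here for illustration only.\n",
    "\n",
    "```python\n",
    "hw_us = 8.00       # hardware inference time reported above, in microseconds\n",
    "sw_us = 18588.00   # software inference time reported above, in microseconds\n",
    "\n",
    "def images_per_second(latency_us):\n",
    "    # One second is 1e6 microseconds\n",
    "    return 1e6 / latency_us\n",
    "\n",
    "print('HW rate: {:.2f} images per second'.format(images_per_second(hw_us)))\n",
    "print('SW rate: {:.2f} images per second'.format(images_per_second(sw_us)))\n",
    "print('Speedup: {:.1f}x'.format(sw_us / hw_us))\n",
    "```\n",
    "\n",
    "For the timings recorded here, the hardware run is about 2300 times faster than the software run on the ARM core."
   ]
  },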
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## 9. Reset the device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pynq import Xlnk\n",
    "\n",
    "xlnk = Xlnk()\n",
    "xlnk.xlnk_reset()"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
